{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "fa28cb38-25b9-48da-901a-2295d07ab7c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x79b2fc07ac90>"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd.functional import jvp\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "929e855b-1734-4c75-ac4b-5223148db6e0",
   "metadata": {},
   "source": [
    "- `torch.autograd.functional.jvp` 底层调用的还是 `torch.autograd.grad`\n",
    "    - `torch.autograd.grad` 定义在计算图上\n",
    "    - `torch.autograd.functional.*` 定义在函数上：\n",
    "        - 即计算函数在某一点处的导数（梯度 / Jacobian）\n",
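    "- Parameters\n",
    "    - func: 待求微分的函数；\n",
    "    - inputs: 求导所在的点，即 `func` 的输入；\n",
    "    - v: 与 Jacobian 相乘的向量，要求 `v.shape == inputs.shape`\n",
    "- Returns\n",
    "    - func_output: `func(inputs)` 的值\n",
    "    - jvp: Jacobian 与 `v` 的乘积，形状与 `func_output` 相同"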
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d2dacee-bed3-4def-8112-5a7d081de906",
   "metadata": {},
   "source": [
    "## single input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "3bf59658-c0e1-4c35-b372-1d7eb401eea3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "    return 2*x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "b4af4241-d294-4901-8ffd-ea7635588c74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1.7645, 1.8300]), tensor([2., 4.]))"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate f and its Jacobian-vector product at a random point x along direction v.\n",
    "x = torch.rand(2)\n",
    "v = torch.tensor([1.0, 2.0])\n",
    "jvp(func=f, inputs=x, v=v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "9415e9de-79c1-46a2-8e7f-9ed8e87d855a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.8823, 0.9150])"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "0d198cab-f184-4a05-bc59-26578300ab0f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2., 0.],\n",
       "        [0., 2.]])"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.functional.jacobian(f, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be8ccade-63df-4712-9739-2de563405b25",
   "metadata": {},
   "source": [
    "### matrix input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "499fb569-78a8-4562-9b77-f5f5c2f675cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([7.2083, 8.5064, 7.8338, 6.0426]),\n",
       " tensor([7.2083, 8.5064, 7.8338, 6.0426]))"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def exp_reducer(x):\n",
    "    return x.exp().sum(dim=1)\n",
    "inputs = torch.rand(4, 4)\n",
    "v = torch.ones(4, 4)\n",
    "jvp(exp_reducer, inputs, v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "1d5cd93b-8e3d-44c6-bedb-b74064e10c0c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.2925, 2.2114, 2.5620, 1.1425],\n",
       "        [2.5462, 1.8105, 2.3855, 1.7642],\n",
       "        [2.0982, 1.5363, 2.4241, 1.7752],\n",
       "        [1.3055, 1.8728, 1.3095, 1.5548]])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs.exp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "13e3539a-68f2-4293-be7e-601b0e325a21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([7.2083, 8.5064, 7.8338, 6.0426])"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs.exp().sum(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "30f6e773-7b05-4663-9fb4-aa9336c79915",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1.2925, 2.2114, 2.5620, 1.1425],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [2.5462, 1.8105, 2.3855, 1.7642],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [2.0982, 1.5363, 2.4241, 1.7752],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [1.3055, 1.8728, 1.3095, 1.5548]]])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.functional.jacobian(exp_reducer, inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "dd2ef94a-3952-4be4-b219-6b153a4e6a79",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 4, 4])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.functional.jacobian(exp_reducer, inputs).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "bb7b79c0-20ca-40af-aceb-e884c716dafc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[7.2083, 7.2083, 7.2083, 7.2083],\n",
       "        [8.5064, 8.5064, 8.5064, 8.5064],\n",
       "        [7.8338, 7.8338, 7.8338, 7.8338],\n",
       "        [6.0426, 6.0426, 6.0426, 6.0426]])"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compute the Jacobian once (instead of once per term) and add up, slice by\n",
    "# slice along the output axis, each slice multiplied by an all-ones matrix.\n",
    "jacobian_full = torch.autograd.functional.jacobian(exp_reducer, inputs)\n",
    "ones_matrix = torch.ones(4, 4)\n",
    "# Iterating a tensor yields its slices along dim 0, so this is\n",
    "# jacobian_full[0] @ ones + jacobian_full[1] @ ones + ... for every output.\n",
    "sum(jacobian_slice @ ones_matrix for jacobian_slice in jacobian_full)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "561c51d9-1139-484a-b063-9a003c512e0a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[7.2083, 7.2083, 7.2083, 7.2083],\n",
       "        [8.5064, 8.5064, 8.5064, 8.5064],\n",
       "        [7.8338, 7.8338, 7.8338, 7.8338],\n",
       "        [6.0426, 6.0426, 6.0426, 6.0426]])"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Batched product of the full Jacobian with v, then reduced over dim 1.\n",
    "(torch.autograd.functional.jacobian(exp_reducer, inputs) @ v).sum(dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "791ed8ca-d85f-4e2a-8e0b-5a2a68dccc0e",
   "metadata": {},
   "source": [
    "## multi inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "c8c5dc8a-13fd-4b59-b072-7f56c8f924b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def adder(x, y):\n",
    "    return 2 * x + 3 * y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "4cceb152-ca21-472a-a316-8617c66f7c75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor([0.3829, 0.9593]), tensor([0.3904, 0.6009])),\n",
       " (tensor([1., 2.]), tensor([2., 3.])))"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = (torch.rand(2), torch.rand(2))\n",
    "v = (torch.tensor([1., 2.]), torch.tensor([2., 3.]))\n",
    "inputs, v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "7cbf37c3-3c7a-4f04-81e7-0c9836cbb41a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1.9371, 3.7213]), tensor([ 8., 13.]))"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jvp(adder, inputs, v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "cfd4abc1-c4b0-46b7-ab37-ad4151ede20e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2., 4.])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Contribution of the first argument: Jacobian of x -> 2x at inputs[0], times v[0].\n",
    "torch.matmul(torch.autograd.functional.jacobian(lambda x: 2 * x, inputs[0]), v[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "d980b586-5bd1-43b8-8d1b-6e7c201c3010",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6., 9.])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Contribution of the second argument: Jacobian of y -> 3y at inputs[1], times v[1].\n",
    "torch.matmul(torch.autograd.functional.jacobian(lambda y: 3 * y, inputs[1]), v[1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
